{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Restrict CUDA to device index 0; this is set before TensorFlow is imported in the next cell.\n",
    "os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSON text is UTF-8 by specification, so do not rely on the platform's default encoding.\n",
    "with open('dataset-bpe.json', encoding='utf-8') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the four splits stored in the dataset file.\n",
    "train_X, train_Y = data['train_X'], data['train_Y']\n",
    "test_X, test_Y = data['test_X'], data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reserved token ids: GO is fed to the decoder as its first input, EOS ends a decoded sequence.\n",
    "GO, EOS = 1, 2\n",
    "# Number of rows in the embedding table and width of the output projection.\n",
    "vocab_size = 32000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Terminate every target sequence with EOS. The lists are rebuilt from `data`, so\n",
    "# re-running this cell cannot append a second EOS to already-terminated sequences.\n",
    "train_Y = [tokens + [EOS] for tokens in data['train_Y']]\n",
    "test_Y = [tokens + [EOS] for tokens in data['test_Y']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import beam_search\n",
    "\n",
    "def pad_second_dim(x, desired_size):\n",
    "    \"\"\"Append float zeros along axis 1 of a rank-3 tensor until it has `desired_size` steps.\"\"\"\n",
    "    dims = tf.shape(x)\n",
    "    missing_steps = desired_size - dims[1]\n",
    "    zeros = tf.tile([[[0.0]]], tf.stack([dims[0], missing_steps, dims[2]], 0))\n",
    "    return tf.concat([x, zeros], 1)\n",
    "\n",
    "class Translator:\n",
    "    \"\"\"LSTM encoder-decoder with Luong attention, built as a TensorFlow 1.x graph.\n",
    "\n",
    "    The constructor reads `vocab_size`, `GO` and `EOS` from the enclosing namespace.\n",
    "    Sequence lengths are obtained by counting the non-zero ids in each row of the\n",
    "    fed batches, so id 0 acts as padding.\n",
    "\n",
    "    Args:\n",
    "        size_layer: number of units of every LSTM cell and of the attention layer.\n",
    "        num_layers: number of stacked LSTM cells in the encoder and in the decoder.\n",
    "        embedded_size: width of the embedding table shared by encoder and decoder.\n",
    "        learning_rate: learning rate passed to the Adam optimizer.\n",
    "        beam_width: kept for interface compatibility; currently unused, decoding is greedy.\n",
    "\n",
    "    Attributes:\n",
    "        X, Y: int32 placeholders of shape [batch, time] for source and target ids.\n",
    "        X_seq_len, Y_seq_len: per-row count of non-zero ids in X and Y.\n",
    "        training_logits: decoder logits obtained with the target fed as decoder input.\n",
    "        fast_result: greedily decoded ids, at most twice the longest source length.\n",
    "        cost: sequence loss over the non-padded target positions.\n",
    "        optimizer: Adam op minimising `cost`.\n",
    "        prediction: argmax of `training_logits` at the non-padded target positions.\n",
    "        accuracy: fraction of `prediction` entries equal to the corresponding target id.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, learning_rate,\n",
    "                beam_width = 5):\n",
    "\n",
    "        def cell(size, residual, reuse=False):\n",
    "            \"\"\"Build one orthogonally initialised LSTM cell, wrapped residually if requested.\"\"\"\n",
    "            c = tf.nn.rnn_cell.LSTMCell(size,initializer=tf.orthogonal_initializer(),reuse=reuse)\n",
    "            if residual:\n",
    "                c = tf.nn.rnn_cell.ResidualWrapper(c)\n",
    "            return c\n",
    "\n",
    "        def cells(size = size_layer, residual = 1, reuse=False):\n",
    "            \"\"\"Build `num_layers` cells; layers with index >= `residual` are residual.\"\"\"\n",
    "            return [cell(size, i >= residual, reuse=reuse) for i in range(num_layers)]\n",
    "\n",
    "        def attention(encoder_out, seq_len, reuse=False):\n",
    "            \"\"\"Wrap a fresh cell stack with Luong attention over `encoder_out`.\"\"\"\n",
    "            attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units = size_layer,\n",
    "                                                                    memory = encoder_out,\n",
    "                                                                    memory_sequence_length = seq_len)\n",
    "            return tf.contrib.seq2seq.AttentionWrapper(\n",
    "                cell = tf.nn.rnn_cell.MultiRNNCell(cells(reuse=reuse)),\n",
    "                attention_mechanism = attention_mechanism,\n",
    "                attention_layer_size = size_layer)\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "\n",
    "        embeddings = tf.Variable(tf.random_uniform([vocab_size, embedded_size], -1, 1))\n",
    "\n",
    "        encoder_out, encoder_state = tf.nn.dynamic_rnn(\n",
    "            cell = tf.nn.rnn_cell.MultiRNNCell(cells()),\n",
    "            inputs = tf.nn.embedding_lookup(embeddings, self.X),\n",
    "            sequence_length = self.X_seq_len,\n",
    "            dtype = tf.float32)\n",
    "        # Decoder input is the target shifted right: drop the last column, prepend GO.\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        dense = tf.layers.Dense(vocab_size)\n",
    "        decoder_cells = attention(encoder_out, self.X_seq_len)\n",
    "\n",
    "        # The decoder starts from the encoder's final state.\n",
    "        states = decoder_cells.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state)\n",
    "\n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = states,\n",
    "                output_layer = dense)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "\n",
    "        # Inference reuses the same decoder cell and projection, fed with its own greedy picks.\n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = states,\n",
    "                output_layer = dense)\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.fast_result = predicting_decoder_output.sample_id\n",
    "\n",
    "        # Float mask: used as per-position weights so padded steps do not contribute to the loss.\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        y_t = tf.cast(tf.argmax(self.training_logits,axis=2), tf.int32)\n",
    "        # tf.boolean_mask is documented to take a boolean mask, so convert the float weights.\n",
    "        bool_masks = tf.cast(masks, tf.bool)\n",
    "        self.prediction = tf.boolean_mask(y_t, bool_masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, bool_masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model dimensions passed to Translator.\n",
    "size_layer, num_layers, embedded_size = 512, 2, 256\n",
    "# Optimisation settings: Adam learning rate, mini-batch size, number of passes over the data.\n",
    "learning_rate = 1e-3\n",
    "batch_size, epoch = 128, 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-10-f9e7759d7799>:42: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:From <ipython-input-10-f9e7759d7799>:45: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:958: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:962: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn.py:244: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Release the session left over from a previous run of this cell; otherwise TensorFlow\n",
    "# warns that an interactive session is already active and keeps holding its resources.\n",
    "try:\n",
    "    sess.close()\n",
    "except NameError:\n",
    "    pass  # first run in this kernel: there is no earlier session to close\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator(size_layer, num_layers, embedded_size, learning_rate)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short alias for the Keras helper that pads variable-length id lists into one array.\n",
    "_keras_sequence = tf.keras.preprocessing.sequence\n",
    "pad_sequences = _keras_sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[29047, 30444, 14561,  2230,  5786, 26561, 31618, 31001, 31001,\n",
       "         17966,  2667, 18957, 28548, 31477, 31477, 11185, 24211, 14213,\n",
       "         31284, 22522, 22781, 29498, 29498, 25449,  7342,  7342,  9649,\n",
       "          9649, 19618,  9649,  9649, 19618, 19721,  9649, 17234, 17234,\n",
       "         29692, 29220, 18658, 31931, 18658,   918,  7650,  7650,   918,\n",
       "         12554, 12554,  7166,  7166, 10922, 10922, 19621,  9684, 26497,\n",
       "         11469, 11469,  8126, 11469, 29371,  3538,  7402, 31331,  4123,\n",
       "          4626,  3462, 24827, 21869, 22296, 16138,  1841, 30793, 27284],\n",
       "        [ 4497, 22485, 22485,  7264,  7264, 27848,  9481,  7264,  7627,\n",
       "         12491,  9481, 12491,    60,    60,   952,   950,  7049, 12491,\n",
       "         11850, 12530,  8881,  6790, 24012, 17025, 22374, 30923, 15264,\n",
       "         26414, 10909, 10610, 10610, 10610, 10610, 30261,  3202, 30694,\n",
       "         30694, 27625, 12202,  5254, 10782,   602, 25117, 28548,  3751,\n",
       "         26891,  2449, 23919, 23919, 30125,   788,   788,   788, 23986,\n",
       "         13873,  4432,  7699, 16419,  5933, 10076, 15445,  7699,  8051,\n",
       "          8051,  8051,  1001,  8051,  6632,  6632, 24899,  4096, 30416],\n",
       "        [27600, 31046, 15655, 10677, 27485, 17750, 23156, 29801,  6619,\n",
       "          3613, 20412, 12057,  2108, 21962, 11808, 16883, 16883, 23269,\n",
       "         19531, 19531, 28856,  4047, 24010, 20950,  4507,  4811, 11449,\n",
       "         11449, 24507, 24507,  6381,  5922, 17041, 11177,  9716, 27829,\n",
       "          8384,  8384,   264,  8061,  8061,  8061, 15063, 15756, 10785,\n",
       "         15756, 22878,  7502, 15996, 19184,  5900,  4267,  6479, 15946,\n",
       "         20950,  1893, 30627, 24507, 24507, 24507,  4400, 20908, 21159,\n",
       "         21159, 21159, 19483, 10677, 28432,  3372, 10534, 28432, 18325],\n",
       "        [12203, 18996, 12409,  3728, 28713,   974, 16122, 31624,  7188,\n",
       "          3281,  5512,  5512, 20937,  3771, 20468, 25068, 30387,  4656,\n",
       "         15289, 13059, 31830, 24046, 30315, 30315, 26249, 25746, 18810,\n",
       "         29699,  9023,  9362,  9023, 11538, 31371, 16471, 16471,  7921,\n",
       "         26414, 31371, 24578, 22878, 15967, 28207,  9362, 12104, 12104,\n",
       "         21523,  6658, 10229, 22323, 31734, 13204, 10534,  4237, 28200,\n",
       "         28200,  6381, 28200,  7821,  3177, 26065, 28200, 30670, 28200,\n",
       "         16000, 16000, 20426,  1575, 15335, 30038, 30038,  3336, 27314],\n",
       "        [13686, 17492, 11617, 11874,  2976,  2976, 20012, 20012, 20012,\n",
       "          1388, 12356,  8616, 22590, 27891,   398,  9669, 27891,  9669,\n",
       "         30034, 30034, 27712, 29167, 29167, 26446, 29167, 26446, 29167,\n",
       "         16921, 11874, 16921, 11874,  2976,  2976, 20768, 26796,   398,\n",
       "         14140,  8675,  3642, 11806, 14191,  5071,  7559, 29590,  5071,\n",
       "         21494, 25894, 17058, 25894, 15782, 10105, 15782,  8381,  8381,\n",
       "          7026, 13375,  9622, 13375, 13375, 12285, 12285, 26249,  8521,\n",
       "         12090,  9165, 21815, 12078, 12078, 19783, 25869, 10050, 13126],\n",
       "        [22945, 14731, 11178,  3654, 20781, 26710, 16267, 29306, 10982,\n",
       "         28381,  8969,   291, 29901, 21452, 21452, 21224, 14876, 14365,\n",
       "         14876, 14365, 20859, 17800, 16687, 19731, 28533, 28533, 28533,\n",
       "         28533, 28533,  6168, 22572,  5254,   291,   291, 11386, 17170,\n",
       "         13800, 15713, 18884,  9239, 14699, 14699, 29769,  1522, 14699,\n",
       "         29769,  1522,   482, 14570, 14570, 13912, 25764, 30482, 30987,\n",
       "         20567, 22999, 10553, 18381, 27172,  7170, 22945, 26476, 26476,\n",
       "          8434,  8434, 19247, 15831, 15831, 30940, 31544, 16791,  9225],\n",
       "        [ 3466, 25681,   208,   157, 29323,  1590, 25711,  8719,  3466,\n",
       "         23595, 28953, 11591, 19537, 21211, 25979, 13889,  7861, 18367,\n",
       "          4498, 12837,  4498, 12837, 18798, 12837, 13188, 31843, 31843,\n",
       "         11664,  1982,  7322,  7322,  5844,  5844,  3576,  3576,  3695,\n",
       "          7694,  1434, 17334, 14267, 13756,  2447, 20441, 22909, 22909,\n",
       "           499, 19522,   499, 19522, 16526,  2098, 16526, 21833,  9075,\n",
       "         16526, 27592,  2153, 27592,  1217, 16403,  6350, 24486,  5396,\n",
       "          5396,  6641,  7024, 12966, 14725, 15103, 21713,  2293, 15776],\n",
       "        [22319, 18195, 20505,  8671, 12086,  7314, 26546,  9895, 19457,\n",
       "         13969, 29995, 12627, 12627, 17109, 17109, 17109, 26245, 23099,\n",
       "         13119,  2844, 13119, 17261, 25674, 25674, 19698,   550, 25720,\n",
       "          6291, 29417, 13609, 29417,  4896,  1632,  2591,  9692, 25012,\n",
       "         25956,  6682,  6682, 27489, 27393,  9788, 22855,  4886, 30813,\n",
       "          4896, 30813, 26792, 24484, 20746, 15829, 31568, 11678, 21187,\n",
       "         30125, 19829, 28550, 11747, 11747, 22008, 11747, 25970, 22008,\n",
       "         29645,  4829, 18421,  6507, 10083,  1061,  2219,  2219,  2777],\n",
       "        [   95, 28921, 10637,  1132,  7276, 20128, 20128,  5244, 15624,\n",
       "         14703, 15624, 17716, 17716,  2400,  2400, 31210, 14543,  7864,\n",
       "         19929, 30050, 14663,  4585,  4585,  3062,  6381,  6381,    42,\n",
       "          6511,  6511, 23585, 28268,  3182,  7383, 19835, 19835,  4645,\n",
       "          2983, 12526,   839, 17420,   321,  3637,  3637,  3637,  4560,\n",
       "          3637,  4560,  3637,  4560,  2799, 14897,    42, 11564, 17399,\n",
       "           448, 10749, 10749, 21618,  4630, 25664, 20977, 20015,  6139,\n",
       "         25811, 24246, 25811, 21178, 30695, 30695, 18636, 21255,  1109],\n",
       "        [26750,  4249, 25269, 16495, 16495, 15917, 28768, 29954, 27775,\n",
       "         12419,   109, 23109, 15289, 19383, 16502, 16502, 16502, 18788,\n",
       "         16323, 25332,  8173, 17275,  5936, 26782,   746,  3718, 22759,\n",
       "         26978, 15025, 14553, 14553, 31660, 29277, 24142,  1462,  4841,\n",
       "          4841,  8717, 27579, 29360, 22154, 22154, 30439, 30439, 30439,\n",
       "         31972, 25211, 24329, 16938, 14171, 16938, 23528, 30439, 30439,\n",
       "          4501, 31972, 16938,  6580, 29447, 23528, 30439, 30439,  6793,\n",
       "         15412, 28569, 28569, 25183, 25183, 18715, 25183,  2876,  2876]],\n",
       "       dtype=int32), 10.373421, 0.0]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: run the freshly initialised model on the first ten training pairs.\n",
    "batch_x = pad_sequences(train_X[:10], padding='post')\n",
    "batch_y = pad_sequences(train_Y[:10], padding='post')\n",
    "\n",
    "fetches = [model.fast_result, model.cost, model.accuracy]\n",
    "sess.run(fetches, feed_dict = {model.X: batch_x, model.Y: batch_y})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:57<00:00,  2.38it/s, accuracy=0.33, cost=4.1]  \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.86it/s, accuracy=0.371, cost=3.73]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg loss 5.061681, training avg acc 0.239105\n",
      "epoch 1, testing avg loss 3.912996, testing avg acc 0.348551\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.448, cost=3.07]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.473, cost=2.96]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg loss 3.417761, training avg acc 0.405378\n",
      "epoch 2, testing avg loss 3.194276, testing avg acc 0.434877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.53, cost=2.46] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.484, cost=2.7] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg loss 2.765453, training avg acc 0.482058\n",
      "epoch 3, testing avg loss 2.859178, testing avg acc 0.477126\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.58, cost=2.02] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.532, cost=2.52]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg loss 2.373655, training avg acc 0.532350\n",
      "epoch 4, testing avg loss 2.728004, testing avg acc 0.496743\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.639, cost=1.69]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.96it/s, accuracy=0.548, cost=2.52]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg loss 2.097324, training avg acc 0.570312\n",
      "epoch 5, testing avg loss 2.693827, testing avg acc 0.506010\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:01<00:00,  2.36it/s, accuracy=0.678, cost=1.43]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.559, cost=2.41]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg loss 1.885408, training avg acc 0.601276\n",
      "epoch 6, testing avg loss 2.712333, testing avg acc 0.508490\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.707, cost=1.25]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.95it/s, accuracy=0.57, cost=2.45] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg loss 1.720545, training avg acc 0.626045\n",
      "epoch 7, testing avg loss 2.767299, testing avg acc 0.508405\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.733, cost=1.12]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.548, cost=2.57]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg loss 1.574601, training avg acc 0.649324\n",
      "epoch 8, testing avg loss 2.842505, testing avg acc 0.505702\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.769, cost=0.952]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.92it/s, accuracy=0.532, cost=2.57]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg loss 1.447412, training avg acc 0.670523\n",
      "epoch 9, testing avg loss 2.921872, testing avg acc 0.502153\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.786, cost=0.846]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.532, cost=2.72]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg loss 1.342741, training avg acc 0.688246\n",
      "epoch 10, testing avg loss 3.028769, testing avg acc 0.497316\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:05<00:00,  2.35it/s, accuracy=0.813, cost=0.763]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.92it/s, accuracy=0.565, cost=2.69]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg loss 1.253272, training avg acc 0.703593\n",
      "epoch 11, testing avg loss 3.102829, testing avg acc 0.497557\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:05<00:00,  2.35it/s, accuracy=0.825, cost=0.719]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.565, cost=2.72]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg loss 1.172357, training avg acc 0.717938\n",
      "epoch 12, testing avg loss 3.178466, testing avg acc 0.496973\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.828, cost=0.668]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.522, cost=2.89]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg loss 1.101403, training avg acc 0.730972\n",
      "epoch 13, testing avg loss 3.291393, testing avg acc 0.493018\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:05<00:00,  2.35it/s, accuracy=0.853, cost=0.603]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.527, cost=2.97]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg loss 1.041070, training avg acc 0.742019\n",
      "epoch 14, testing avg loss 3.370171, testing avg acc 0.490531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:05<00:00,  2.35it/s, accuracy=0.847, cost=0.583]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.95it/s, accuracy=0.516, cost=3.04]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg loss 0.990729, training avg acc 0.751395\n",
      "epoch 15, testing avg loss 3.461201, testing avg acc 0.487935\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:05<00:00,  2.35it/s, accuracy=0.87, cost=0.534] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.92it/s, accuracy=0.565, cost=2.94]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg loss 0.940511, training avg acc 0.760863\n",
      "epoch 16, testing avg loss 3.535542, testing avg acc 0.486372\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:05<00:00,  2.35it/s, accuracy=0.873, cost=0.487]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.96it/s, accuracy=0.538, cost=3.22]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg loss 0.900997, training avg acc 0.768394\n",
      "epoch 17, testing avg loss 3.626970, testing avg acc 0.482795\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.87, cost=0.476] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.89it/s, accuracy=0.548, cost=3.06]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg loss 0.859984, training avg acc 0.776790\n",
      "epoch 18, testing avg loss 3.706382, testing avg acc 0.484905\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:04<00:00,  2.35it/s, accuracy=0.889, cost=0.416]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.554, cost=3.28]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg loss 0.827677, training avg acc 0.783131\n",
      "epoch 19, testing avg loss 3.791811, testing avg acc 0.483511\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [11:03<00:00,  2.35it/s, accuracy=0.893, cost=0.401]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.97it/s, accuracy=0.565, cost=3.15]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg loss 0.793020, training avg acc 0.790365\n",
      "epoch 20, testing avg loss 3.864456, testing avg acc 0.483229\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "def run_epoch(inputs, targets, train_op = None):\n",
    "    \"\"\"Run one pass over (inputs, targets) in mini-batches of `batch_size`.\n",
    "\n",
    "    When `train_op` is given it is executed with every batch, so the pass also\n",
    "    updates the model; without it the pass only evaluates.\n",
    "\n",
    "    Returns:\n",
    "        (losses, accuracies): two lists holding one value per mini-batch.\n",
    "    \"\"\"\n",
    "    fetches = [model.accuracy, model.cost]\n",
    "    if train_op is not None:\n",
    "        fetches.append(train_op)\n",
    "    losses, accuracies = [], []\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(inputs), batch_size), desc = 'minibatch loop')\n",
    "    for start in pbar:\n",
    "        # Slicing clamps at the end of the list, so the last batch is simply shorter.\n",
    "        stop = start + batch_size\n",
    "        feed = {model.X: pad_sequences(inputs[start : stop], padding='post'),\n",
    "                model.Y: pad_sequences(targets[start : stop], padding='post')}\n",
    "        outputs = sess.run(fetches, feed_dict = feed)\n",
    "        accuracy, loss = outputs[0], outputs[1]\n",
    "        losses.append(loss)\n",
    "        accuracies.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    return losses, accuracies\n",
    "\n",
    "# NOTE(review): batches are visited in the same order every epoch (no shuffling).\n",
    "for e in range(epoch):\n",
    "    train_loss, train_acc = run_epoch(train_X, train_Y, train_op = model.optimizer)\n",
    "    test_loss, test_acc = run_epoch(test_X, test_Y)\n",
    "\n",
    "    # Averages are taken over mini-batches, not over individual tokens.\n",
    "    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,\n",
    "                                                                 np.mean(train_loss),np.mean(train_acc)))\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,\n",
    "                                                              np.mean(test_loss),np.mean(test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import bleu_hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 40/40 [00:19<00:00,  2.10it/s]\n"
     ]
    }
   ],
   "source": [
    "# Greedy-decode the test set batch by batch.\n",
    "results = []\n",
    "for start in tqdm.tqdm(range(0, len(test_X), batch_size)):\n",
    "    stop = min(start + batch_size, len(test_X))\n",
    "    batch_x = pad_sequences(test_X[start : stop], padding='post')\n",
    "    decoded = sess.run(model.fast_result, feed_dict = {model.X: batch_x})\n",
    "    # Keep only ids above 3; ids 0-3 are presumably reserved/special tokens - confirm.\n",
    "    results.extend([token for token in row if token > 3] for row in decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# References get the same filter as the decoded output: only ids above 3 are kept.\n",
    "rights = [[token for token in reference if token > 3] for reference in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.1475228"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# BLEU of the greedy translations against the filtered test references.\n",
    "bleu_hook.compute_bleu(\n",
    "    translation_corpus = results,\n",
    "    reference_corpus = rights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
